import torch
import torch.nn as nn
import math
import smooth_dp_utils


class VariationalLayer(nn.Module):
    """
    Fully connected Bayesian layer with a factorised Gaussian posterior.

    Every weight (plus one extra input row that acts as the bias) has a
    trainable mean ``theta_mu`` and a trainable ``theta_rho``; the standard
    deviation is ``softplus(theta_rho)``. Each call draws ``n_samples``
    independent weight matrices as ``w = mu + sigma * eps`` with standard
    normal ``eps``.

    Args:
        input_size: number of input features (the bias row is added
            internally).
        output_size: number of output features.
        prior_mu: mean of the Gaussian prior.
        prior_rho: prior scale parameter; it goes through the same softplus
            mapping as ``theta_rho``.
        n_samples: number of weight samples drawn per call.
        dev: device on which parameters, constants and noise are placed.
        mu_init_1: lower bound of the uniform initialisation of ``theta_mu``.
        mu_init_2: upper bound of the uniform initialisation of ``theta_mu``.
        rho_init: lower bound of the uniform initialisation of
            ``theta_rho``; the upper bound is ``rho_init + 1``.
    """
    def __init__(self,
                 input_size, output_size,
                 prior_mu, prior_rho,
                 n_samples, dev, mu_init_1=-0.2, mu_init_2=0.2, rho_init=-5
                ):
        super().__init__()

        self.dev = dev

        # One extra input row holds the bias weights.
        input_size = input_size + 1

        # Prior distribution (Gaussian).
        self.prior_mu = torch.tensor(prior_mu).to(dev)
        self.prior_rho = torch.tensor(prior_rho).to(dev)

        # Variational family (Gaussian). The tensors are created as float32
        # up front: casting an nn.Parameter afterwards only keeps it a
        # registered parameter when the cast happens to be a no-op.
        self.theta_mu = nn.Parameter(
            torch.empty(
                input_size, output_size, dtype=torch.float32, device=dev
            ).uniform_(mu_init_1, mu_init_2))
        self.theta_rho = nn.Parameter(
            torch.empty(
                input_size, output_size, dtype=torch.float32, device=dev
            ).uniform_(rho_init, rho_init + 1))

        # Constants: log(sqrt(2*pi)) for the Gaussian log-density.
        self.logsqrttwopi = torch.log(
            torch.sqrt(2*torch.tensor(math.pi))).to(dev)
        self.K = torch.tensor(1).to(dev)

        # Number of weight samples drawn per forward pass.
        self.n_samples = n_samples

    def rho_to_sigma(self, rho):
        """Return ``log(1 + exp(rho))`` (softplus), a positive std-dev.

        Computed with ``torch.logaddexp(0, rho)`` so that a large ``rho``
        does not overflow to ``inf`` and a strongly negative ``rho`` does
        not collapse to exactly zero (which would produce ``log(0)`` and a
        division by zero in ``log_prob_gaussian``).
        """
        if not torch.is_floating_point(rho):
            # Integer priors (e.g. ``prior_rho=1``) are promoted to float,
            # matching what ``torch.exp`` did in the naive formula.
            rho = rho.to(torch.get_default_dtype())
        return torch.logaddexp(torch.zeros_like(rho), rho)

    def sample_weight(self):
        """Draw weights of shape ``(n_samples, input_size + 1, output_size)``."""
        mu = self.theta_mu.to(self.dev)
        sigma = self.rho_to_sigma(self.theta_rho.to(self.dev))
        # Noise is generated directly on the target device instead of being
        # created on the CPU and copied over on every call.
        eps = torch.randn(
            (self.n_samples, mu.shape[0], mu.shape[1]),
            dtype=mu.dtype, device=self.dev)
        return mu + sigma*eps

    def log_prob_gaussian(self, x, mu, rho):
        """Gaussian log-density of ``x``, summed over dims 1 and 2 and
        averaged over dim 0 (the sample dimension of ``sample_weight``)."""
        sigma = self.rho_to_sigma(rho)  # evaluated once, used twice below
        log_density = (
            - self.logsqrttwopi
            - torch.log(sigma)
            - ((x - mu)**2)/(2*sigma**2)
        )
        return log_density.sum(dim=(1, 2)).mean()

    def prior(self, w):
        """Average log-density of the weight samples ``w`` under the prior."""
        return self.log_prob_gaussian(
            w, self.prior_mu, self.prior_rho)

    def variational(self, w):
        """Average log-density of ``w`` under the variational posterior."""
        return self.log_prob_gaussian(
            w, self.theta_mu, self.theta_rho)

    def kl_divergence_layer(self):
        """Monte Carlo estimate of KL(posterior || prior) for this layer.

        A fresh set of weights is sampled here; it is independent of the
        weights drawn in ``forward``.
        """
        w = self.sample_weight()
        Q = self.variational(w)
        P = self.prior(w)
        return Q - P

    def forward(self, x_layer):
        """Apply one sampled affine map per sample slice.

        ``x_layer`` is batch-multiplied with the sampled weights, so it
        must be 3-D with shape ``(n_samples, batch, input_size)``; the
        result has shape ``(n_samples, batch, output_size)``.
        """
        w = self.sample_weight()
        # All rows but the last are the weight matrix; the last row is the
        # bias, broadcast over the batch dimension.
        weights, bias = w[:, :-1, :], w[:, -1, :].unsqueeze(1)
        return torch.bmm(x_layer.to(self.dev), weights) + bias
    
    

class ANN(nn.Module):
    """
    Deterministic feed-forward network with three ReLU hidden layers.

    ``hl_sizes[0]`` is the width of the first hidden layer and
    ``hl_sizes[1]`` the width of the second and third. ``n_hidden_layers``
    is accepted but not used by this class.
    """
    def __init__(self, input_size, output_size, n_hidden_layers=3, hl_sizes=[64, 64]):
        super().__init__()

        self.act1 = nn.ReLU()
        self.linear1 = nn.Linear(in_features=input_size, out_features=hl_sizes[0])
        self.linear2 = nn.Linear(in_features=hl_sizes[0], out_features=hl_sizes[1])
        self.linear3 = nn.Linear(in_features=hl_sizes[1], out_features=hl_sizes[1])
        self.linear4 = nn.Linear(in_features=hl_sizes[1], out_features=output_size)

    def forward(self, x):
        """Return the prediction with an extra leading dimension of size 1."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = self.act1(layer(hidden))
        return self.linear4(hidden).unsqueeze(0)
    
    

class ANNClass(nn.Module):
    """
    Feed-forward network with three ReLU hidden layers and a sigmoid output.

    ``hl_sizes[0]`` is the width of the first hidden layer and
    ``hl_sizes[1]`` the width of the second and third. ``n_hidden_layers``
    is accepted but not used by this class.
    """
    def __init__(self, input_size, output_size, n_hidden_layers=3, hl_sizes=[64, 64]):
        super().__init__()

        self.act1 = nn.ReLU()
        self.linear1 = nn.Linear(in_features=input_size, out_features=hl_sizes[0])
        self.linear2 = nn.Linear(in_features=hl_sizes[0], out_features=hl_sizes[1])
        self.linear3 = nn.Linear(in_features=hl_sizes[1], out_features=hl_sizes[1])
        self.linear4 = nn.Linear(in_features=hl_sizes[1], out_features=output_size)

    def forward(self, x):
        """Return element-wise sigmoid outputs, i.e. values in (0, 1)."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = self.act1(layer(hidden))
        logits = self.linear4(hidden)
        return torch.sigmoid(logits)
    
     
    
class ANNVar(nn.Module):
    """
    Feed-forward network with four ReLU hidden layers and two linear heads.

    Both heads read the same final hidden activation. ``hl_sizes[0]`` is the
    width of the first hidden layer and ``hl_sizes[1]`` the width of the
    remaining ones. ``n_hidden_layers`` is accepted but not used.
    """
    def __init__(self, input_size, output_size, n_hidden_layers=3, hl_sizes=[64, 64]):
        super().__init__()

        self.act1 = nn.ReLU()
        self.linear1 = nn.Linear(in_features=input_size, out_features=hl_sizes[0])
        self.linear2 = nn.Linear(in_features=hl_sizes[0], out_features=hl_sizes[1])
        self.linear3 = nn.Linear(in_features=hl_sizes[1], out_features=hl_sizes[1])
        self.linear3A = nn.Linear(in_features=hl_sizes[1], out_features=hl_sizes[1])
        self.linear4 = nn.Linear(in_features=hl_sizes[1], out_features=output_size)
        self.linear4_2 = nn.Linear(in_features=hl_sizes[1], out_features=output_size)

    def forward(self, x):
        """Return ``(y_avg, y_std)``.

        Both outputs are raw linear-head values: no positivity transform is
        applied to ``y_std`` here, and neither output gets an extra leading
        dimension.
        """
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3, self.linear3A):
            hidden = self.act1(layer(hidden))
        return self.linear4(hidden), self.linear4_2(hidden)
    
    
class VariationalANN(nn.Module):
    """
    Bayesian feed-forward network made of four ``VariationalLayer`` blocks.

    Every layer uses a prior mean of 0 and a prior rho of ``plv``, and draws
    ``n_samples`` weight samples per call. ``hl_sizes[0]`` is the width of
    the first hidden layer and ``hl_sizes[1]`` the width of the second and
    third. ``n_hidden_layers`` is accepted but not used.
    """
    def __init__(self, n_samples, input_size, output_size, n_hidden_layers=3, hl_sizes=[64, 64], plv=1, dev='cpu'):
        super().__init__()
        self.output_type_dist = True
        self.n_samples = n_samples
        self.act1 = nn.ReLU()

        mu_bound = 0.2
        rho_start = -5

        def bayes_layer(n_in, n_out):
            # All layers share the same prior and initialisation settings.
            return VariationalLayer(
                n_in, n_out, 0, plv, n_samples, dev,
                -mu_bound, mu_bound, rho_start)

        self.linear1 = bayes_layer(input_size, hl_sizes[0])
        self.linear2 = bayes_layer(hl_sizes[0], hl_sizes[1])
        self.linear3 = bayes_layer(hl_sizes[1], hl_sizes[1])
        self.linear4 = bayes_layer(hl_sizes[1], output_size)

        # Weight count per layer, including the bias row.
        # NOTE(review): the output-layer term is doubled although only one
        # output layer exists in this class -- confirm whether the factor 2
        # is still intended by whatever consumes ``neurons``.
        per_layer = [
            (input_size+1)*hl_sizes[0],
            (hl_sizes[0]+1)*hl_sizes[1],
            (hl_sizes[1]+1)*hl_sizes[1],
            2*(hl_sizes[1]+1)*output_size,
        ]
        self.neurons = sum(per_layer)

    def forward(self, x):
        """Return one prediction per weight sample.

        ``x`` is indexed as a 2-D tensor; a leading sample dimension of size
        ``n_samples`` is added (as a broadcast view) before the layers run.
        """
        stacked = torch.unsqueeze(x, 0)
        hidden = stacked.expand(
            (self.n_samples, stacked.shape[1], stacked.shape[2]))
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = self.act1(layer(hidden))
        return self.linear4(hidden)
    
    
class CombinedModel(nn.Module):
    """
    Combine two sub-models (A and B) with a third model that predicts the
    bridge outputs plus one multiplicative correction factor per sub-model.

    The outputs of ``model_A`` and ``model_B`` are detached, so gradients
    only flow through ``model``. ``nodes_original_all`` is accepted but not
    stored, and the ``M_ind_*`` attributes are stored but not read in
    ``forward``.
    """
    def __init__(self, model_A, model_B, model, prior_A, prior_B,
                 M_ind_original_A, M_ind_original_B, M_ind_original_bridges,
                 sort_prev_A, sort_new_A, sort_prev_B, sort_new_B,
                 nodes_original_all
                ):
        super().__init__()

        # Models.
        self.model = model
        self.model_A = model_A
        self.model_B = model_B

        # Priors added to the sub-model outputs.
        self.prior_A, self.prior_B = prior_A, prior_B

        # Index bookkeeping.
        self.M_ind_A = M_ind_original_A
        self.M_ind_B = M_ind_original_B
        self.M_ind_bridges = M_ind_original_bridges

        # Permutations used to reorder the last axis of the sub-model outputs.
        self.sort_new_A, self.sort_new_B = sort_new_A, sort_new_B
        self.sort_prev_A, self.sort_prev_B = sort_prev_A, sort_prev_B

    @staticmethod
    def _rescaled_delta(delta, prior, factor):
        """Scale the clipped level ``delta + prior`` by ``factor + 1`` and
        return it again as a difference from ``prior``."""
        level = (delta.detach() + prior).clip(0.001, None)
        gain = (factor + 1.).clip(0.01, None)
        return gain * level - prior

    def forward(self, x):
        """Return ``(dY_combined, None)``.

        ``dY_combined`` concatenates, along the last axis, the rescaled A
        outputs, the rescaled B outputs and the bridge outputs of ``model``.
        """
        # Priors get two leading singleton dims and carry no gradient.
        prior_A = self.prior_A.detach()[None, None]
        prior_B = self.prior_B.detach()[None, None]

        dY_comb, _ = self.model(x)
        dY_A, _ = self.model_A(x)
        dY_B, _ = self.model_B(x)

        # In-place reorder of the last axis of the sub-model outputs.
        dY_A[:, :, self.sort_new_A] = dY_A[:, :, self.sort_prev_A]
        dY_B[:, :, self.sort_new_B] = dY_B[:, :, self.sort_prev_B]

        # The last two channels of the combined model are the A and B
        # correction factors; everything before them is the bridge part.
        dY_bridges = dY_comb[:, :, :-2]
        factor_A = dY_comb[:, :, -2].unsqueeze(-1)
        factor_B = dY_comb[:, :, -1].unsqueeze(-1)

        dY_A_new = self._rescaled_delta(dY_A, prior_A, factor_A)
        dY_B_new = self._rescaled_delta(dY_B, prior_B, factor_B)

        return torch.cat([dY_A_new, dY_B_new, dY_bridges], dim=-1), None
    
    
    
class CombinedModel_2(nn.Module):
    """
    Combine several per-cluster models with one model that predicts bridge
    costs, bridge sigmas and one multiplicative factor per cluster.

    The output of ``model`` is sliced along dim 1 as
    ``[bridges | bridge sigmas | n_clusters factors]``.
    """
    def __init__(self, models_single, model, prior_distance,
                 nodes_treshold, M_indices_clust, M_indices_bridge, costs_to_matrix):
        super().__init__()

        self.model = model
        self.costs_to_matrix = costs_to_matrix
        self.nodes_treshold = nodes_treshold
        self.M_indices_bridge = M_indices_bridge
        self.M_indices_clust = M_indices_clust

        # Kept in a plain Python list, so these models are NOT registered as
        # sub-modules of this module.
        # NOTE(review): confirm that excluding them from parameters() and
        # state_dict() is intended.
        self.models_s = [models_single[k] for k in range(len(models_single))]

        # Leading singleton dim so the prior can be repeated per batch item.
        self.prior_distance = prior_distance.unsqueeze(0)

        self.n_clusters = len(models_single)

    def forward(self, x, priors_inner_solved, evalu=False):
        """Return a 7-tuple
        ``(M_pred, M_pred_sigma, dY_bridges, dsigmaY_bridges, factors,
        dY_inners, dsigmaY_inners)``.

        With ``evalu=True`` the first two entries are ``None`` and the last
        two are lists holding each cluster model's two outputs. Otherwise
        the last two entries are ``None`` and the first two are produced by
        ``costs_to_matrix``.
        """
        dY_comb, _ = self.model(x)

        factors = dY_comb[:, -self.n_clusters:]

        # What precedes the factors is split evenly into costs and sigmas.
        n_bridge = (dY_comb.shape[1] - self.n_clusters)//2
        dY_bridges = dY_comb[:, :n_bridge]
        dsigmaY_bridges = dY_comb[:, n_bridge:-self.n_clusters]

        if evalu:
            inner_out = [
                (mean, sigma)
                for mean, sigma in (inner(x) for inner in self.models_s)
            ]
            dY_inners = [mean for mean, _ in inner_out]
            dsigmaY_inners = [sigma for _, sigma in inner_out]
            return None, None, dY_bridges, dsigmaY_bridges, factors, dY_inners, dsigmaY_inners

        # One copy of the prior matrix per batch element.
        M_pred = self.prior_distance.repeat(x.shape[0], 1, 1)

        # Overwrite each cluster's diagonal block with its scaled inner prior.
        # NOTE(review): the gain has shape (batch, 1); whether that
        # broadcasts as intended for batch sizes above 1 depends on the
        # shape of priors_inner_solved[i] -- confirm against the caller.
        for i in range(len(self.models_s)):
            lo = self.nodes_treshold[i] + 1
            hi = self.nodes_treshold[i + 1] + 1
            gain = (factors[:, i] + 1.).clip(0.001, None).unsqueeze(-1)
            M_pred[:, lo:hi, lo:hi] = gain*(priors_inner_solved[i])

        # squeeze() drops every singleton dim, including a batch dim of 1.
        M_pred = self.costs_to_matrix(
            M_pred.squeeze(), self.M_indices_bridge, dY_bridges)
        M_pred_sigma = self.costs_to_matrix(
            0.2*M_pred.squeeze(), self.M_indices_bridge, dsigmaY_bridges)

        return M_pred, M_pred_sigma, dY_bridges, dsigmaY_bridges, factors, None, None